Compare commits

...

3 Commits

Author SHA1 Message Date
c9f7963aeb [dynamo][user-defined] Simplify and improve scope of UserDefinedObject var_getattr
ghstack-source-id: 3ae5569e3914050c7fd2d43b943622f6c5d93c5a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130169
2024-07-06 17:18:23 -07:00
b774deb215 [dynamo][user-defined] Support method descriptors
Fixes https://github.com/pytorch/pytorch/issues/120650

ghstack-source-id: a26c5df2a625804f473eaea9b8bea00526f19cc4
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130159
2024-07-05 14:45:45 -07:00
2de5e650db [dynamo] Validate check_fn
ghstack-source-id: b698395a184ed8264dcb59ae143fe588df95afef
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118448

random
2024-07-05 07:50:30 -07:00
14 changed files with 190 additions and 80 deletions

View File

@ -3428,8 +3428,7 @@ def forward(self, x):
example_inputs = (torch.rand(5),)
with self.assertRaisesRegex(
torch._dynamo.exc.UncapturedHigherOrderOpError,
"Cond doesn't work unless it is captured completely with torch.compile",
RuntimeError, "Unmatched number of outputs from cond"
):
torch._dynamo.export(
f_mismatch_return_length,

View File

@ -10641,6 +10641,37 @@ fn
res = opt_fn(x)
self.assertEqual(ref, res)
def test_descriptor(self):
class lazy_property:
def __init__(self, wrapped):
self.wrapped = wrapped
def __get__(self, instance, obj_type=None):
value = self.wrapped(instance)
setattr(instance, self.wrapped.__name__, value)
return value
class UserDefined:
def __init__(self):
self.a = 3
@lazy_property
def length(self):
return 3
def run(self, x):
return x * self.length
obj = UserDefined()
def fn(x):
return obj.run(x)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.randn(4)
self.assertEqual(fn(x), opt_fn(x))
self.assertEqual(fn(x), opt_fn(x))
def test_assert_size_stride(self):
x = torch.randn(2, 3, 4)
with self.assertRaisesRegex(

View File

@ -1363,6 +1363,9 @@ s1 > 3""",
s = SubTensor(torch.randn(3, 10))
f(s)
# Guard validation upsets the guard
# https://github.com/pytorch/pytorch/issues/129936
@unittest.expectedFailure
def test_recompile_with_symbool_inputs(self):
def f(pred: bool):
if pred:

View File

@ -3601,6 +3601,8 @@ def forward(self, x):
):
torch.export.export(exported_v2.module(), (torch.randn(2, 2),))
# https://github.com/pytorch/pytorch/issues/129939
@testing.expectedFailureNonStrict
def test_export_cond(self):
class A(torch.nn.Module):
def __init__(self):
@ -4976,6 +4978,9 @@ graph():
)
)
# Guard validation upsets the guard
# https://github.com/pytorch/pytorch/issues/129939
@unittest.expectedFailure
def test_cond_with_module_stack_export_with(self):
class Bar(torch.nn.Module):
def __init__(self):

View File

@ -1411,8 +1411,8 @@ def forward(self, arg0_1):
x = torch.randn(4)
with self.assertRaisesRegex(
torch._dynamo.exc.UncapturedHigherOrderOpError,
"Cond doesn't work unless it is captured completely with torch.compile",
torch._dynamo.exc.CondOpArgsMismatchError,
"Expected to return same number of outputs but got:",
):
make_fx(f)(x, torch.tensor(False))
@ -1584,8 +1584,8 @@ def forward(self, arg0_1):
x = torch.randn(4)
with self.assertRaisesRegex(
torch._dynamo.exc.UncapturedHigherOrderOpError,
"Cond doesn't work unless it is captured completely with torch.compile",
torch._dynamo.exc.CondOpArgsMismatchError,
"Expected to return same number of outputs but got:",
):
make_fx(f, tracing_mode="fake")(x, torch.tensor(False))

View File

@ -2346,6 +2346,7 @@ known_failing_tests = {
"test_grad_nonleaf_register_hook", # IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors)
"test_unpack_hooks_exec_count", # pack/unpack saved tensor hooks firing more than once
"test_scalar_grad_mixed_device", # Fake Tensors aren't propagating device properly for 0-dim grads
"test_backward_twice_without_saved_values", # https://github.com/pytorch/pytorch/issues/129938
}
if not HAS_CUDA:

View File

@ -2103,6 +2103,7 @@ class CheckFunctionManager:
guard.create(builder)
self.check_fn = self.compile_check_fn(builder, guards, guard_fail_fn)
# Keep track of weak references of objects with ID_MATCH guard. This
# info is stored alongside optimized_code and check_fn and is used to
# limit the number of cache entries with same ID_MATCH'd object.
@ -2123,6 +2124,18 @@ class CheckFunctionManager:
self.guard_manager.id_matched_objs = builder.id_matched_objs
self.check_fn = self.guard_manager
# Check that the guard returns True. False means that we will always
# recompile.
# TODO(anijain2305, ydwu4) - Skipping export because of following test
# python -s test/dynamo/test_export.py -k test_export_with_symbool_inputs
if not output_graph.export:
if not self.guard_manager.check(output_graph.local_scope):
reasons = get_guard_fail_reason_helper(
self.guard_manager, # type: ignore[arg-type]
output_graph.local_scope,
)
raise AssertionError(f"Guard check failed: {reasons}")
# NB - We have to very careful of cleaning up here. Because of the
# invalidate function, we can create a weakref finalizer that keeps
# `self` alive for very long. Sometimes by mistake, we can run
@ -2456,9 +2469,8 @@ def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope):
return [f"Duplicate tensors found: {reason}"]
def get_guard_fail_reason(
def get_guard_fail_reason_helper(
guard_fn: GuardFn,
code: types.CodeType,
f_locals: Dict[str, object],
) -> str:
"""
@ -2525,6 +2537,15 @@ def get_guard_fail_reason(
break
reason_str = "\n".join(reasons)
return reason_str
def get_guard_fail_reason(
guard_fn: GuardFn,
code: types.CodeType,
f_locals: Dict[str, object],
) -> str:
reason_str = get_guard_fail_reason_helper(guard_fn, f_locals)
guard_failures[orig_code_map[code]].append(reason_str)
try:

View File

@ -587,9 +587,16 @@ def is_wrapper_or_member_descriptor(value):
return isinstance(
value,
(
types.MethodWrapperType,
# set up by PyGetSetDef
types.GetSetDescriptorType,
# set by PyMethodDef, e.g. list.append
types.MethodDescriptorType,
# slots - list.__add__
types.WrapperDescriptorType,
# set up by PyMemberDef
types.MemberDescriptorType,
# wrapper over C functions
types.MethodWrapperType,
),
)

View File

@ -2538,6 +2538,22 @@ class SourcelessBuilder:
handlers[immutable_list] = handlers[list]
handlers[types.ModuleType] = lambda tx, value: PythonModuleVariable(value)
handlers[
torch.distributions.constraints._Real
] = lambda tx, value: UserDefinedObjectVariable(
value, mutable_local=MutableLocal()
)
handlers[
torch.distributions.constraints._Interval
] = lambda tx, value: UserDefinedObjectVariable(
value, mutable_local=MutableLocal()
)
handlers[
torch.distributions.constraints.Constraint
] = lambda tx, value: UserDefinedObjectVariable(
value, mutable_local=MutableLocal()
)
def passthrough(tx, value):
return value

View File

@ -4,11 +4,9 @@ import collections
import contextlib
import enum
import functools
import importlib
import inspect
import itertools
import random
import re
import sys
import threading
import types
@ -45,13 +43,13 @@ from ..source import (
WeakRefCallSource,
)
from ..utils import (
all_hook_names,
build_checkpoint_variable,
check_constant_args,
get_custom_getattr,
has_torch_function,
is_namedtuple_cls,
is_utils_checkpoint,
is_wrapper_or_member_descriptor,
istype,
namedtuple_fields,
object_has_getattribute,
@ -829,7 +827,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
def var_getattr(self, tx, name):
from .. import trace_rules
from . import ConstantVariable
from .builder import VariableBuilder
value = self.value
source = AttrSource(self.source, name) if self.source else None
@ -842,6 +839,13 @@ class UserDefinedObjectVariable(UserDefinedVariable):
options = {"source": source}
return variables.GetAttrVariable(self, name, **options)
# TODO(anijain2305) - Investigate if we need specialization for more
# dunder attrs. inspect.getattr_static does not return correct value for
# them.
if name == "__class__":
options = {"source": source}
return UserDefinedClassVariable(type(self.value), **options)
try:
subobj = self._getattr_static(name)
except AttributeError:
@ -891,11 +895,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
return variables.UserMethodVariable(
subobj.fget, self, source=source
).call_function(tx, [], {})
elif isinstance(subobj, torch.distributions.utils.lazy_property):
subobj_var = UserDefinedObjectVariable(subobj, source=source)
return variables.UserMethodVariable(
subobj.__get__.__func__, subobj_var, source=source
).call_function(tx, [self], {})
elif isinstance(subobj, staticmethod):
func = subobj.__get__(self.value)
if source is not None:
@ -906,6 +905,25 @@ class UserDefinedObjectVariable(UserDefinedVariable):
return variables.UserMethodVariable(
subobj.__func__, self.var_getattr(tx, "__class__"), source=source
)
elif inspect.ismethoddescriptor(subobj) and not is_wrapper_or_member_descriptor(
subobj.__get__
):
# Attribute has a __get__ method. Create a user defined object vt
# for the subobj, and then trace the __get__ method.
descriptor_var = UserDefinedObjectVariable(subobj, source=source)
get_source = self.source
if self.source:
get_source = AttrSource(self.source, "__get__")
# The arguments of the __get__ function are (self, instance, owner)
# self - descriptor_var
# instance - instance of the class, represented by self here
# owner - class object
owner_var = UserDefinedClassVariable(type(self.value))
return variables.UserMethodVariable(
subobj.__get__.__func__, descriptor_var, source=get_source
).call_function(tx, [descriptor_var, self, owner_var], {})
elif isinstance(subobj, types.FunctionType) or (
isinstance(subobj, types.MethodType)
and isinstance(self.value, torch.nn.Module)
@ -943,77 +961,86 @@ class UserDefinedObjectVariable(UserDefinedVariable):
else:
return trace_rules.lookup(func)(func)
if (
name in getattr(value, "__dict__", {})
or ConstantVariable.is_literal(subobj)
or isinstance(
subobj,
(
torch.Tensor,
torch.nn.Module,
re.Pattern,
),
)
):
if subobj is not NO_SUCH_SUBOBJ and not is_wrapper_or_member_descriptor(subobj):
if source:
return variables.LazyVariableTracker.create(subobj, source)
elif ConstantVariable.is_literal(subobj):
return ConstantVariable.create(subobj)
elif (
type(subobj) == torch.utils._pytree.TreeSpec
or type(subobj) == torch.utils._pytree.LeafSpec
or type(value) == torch.utils._pytree.TreeSpec
):
else:
from .builder import SourcelessBuilder
return SourcelessBuilder.create(tx, subobj)
if (
subobj is not NO_SUCH_SUBOBJ
and name not in getattr(value, "__dict__", {})
and (
type(value).__module__.startswith("torch.")
or isinstance(subobj, re.Pattern)
)
and "torch.optim" not in type(value).__module__
and not callable(value)
and not isinstance(subobj, types.MethodDescriptorType)
):
if not source:
assert getattr(
importlib.import_module(type(value).__module__),
type(value).__name__,
) is type(value)
source = AttrSource(
AttrSource(
tx.import_source(type(value).__module__), type(value).__name__
),
name,
)
return VariableBuilder(tx, source)(subobj)
options = {"source": source}
if isinstance(
subobj,
(
torch.distributions.constraints._Interval,
torch.distributions.constraints._Real,
torch.distributions.constraints.Constraint,
),
):
return UserDefinedObjectVariable(subobj, **options)
elif isinstance(self.value, torch.nn.Module) and name in all_hook_names:
assert isinstance(subobj, collections.OrderedDict)
if not subobj:
return variables.ConstDictVariable(
subobj, collections.OrderedDict, **options
)
if name == "__class__":
return UserDefinedClassVariable(type(self.value), **options)
return variables.GetAttrVariable(self, name, **options)
# if (
# name in getattr(value, "__dict__", {})
# or ConstantVariable.is_literal(subobj)
# or isinstance(
# subobj,
# (
# torch.Tensor,
# torch.nn.Module,
# re.Pattern,
# ),
# )
# ):
# if source:
# return variables.LazyVariableTracker.create(subobj, source)
# elif ConstantVariable.is_literal(subobj):
# return ConstantVariable.create(subobj)
# elif (
# type(subobj) == torch.utils._pytree.TreeSpec
# or type(subobj) == torch.utils._pytree.LeafSpec
# or type(value) == torch.utils._pytree.TreeSpec
# ):
# from .builder import SourcelessBuilder
# return SourcelessBuilder.create(tx, subobj)
# if (
# subobj is not NO_SUCH_SUBOBJ
# and name not in getattr(value, "__dict__", {})
# and (
# type(value).__module__.startswith("torch.")
# or isinstance(subobj, re.Pattern)
# )
# and "torch.optim" not in type(value).__module__
# and not callable(value)
# and not isinstance(subobj, types.MethodDescriptorType)
# ):
# if not source:
# assert getattr(
# importlib.import_module(type(value).__module__),
# type(value).__name__,
# ) is type(value)
# source = AttrSource(
# AttrSource(
# tx.import_source(type(value).__module__), type(value).__name__
# ),
# name,
# )
# return VariableBuilder(tx, source)(subobj)
# options = {"source": source}
# if isinstance(
# subobj,
# (
# torch.distributions.constraints._Interval,
# torch.distributions.constraints._Real,
# torch.distributions.constraints.Constraint,
# ),
# ):
# return UserDefinedObjectVariable(subobj, **options)
# elif isinstance(self.value, torch.nn.Module) and name in all_hook_names:
# assert isinstance(subobj, collections.OrderedDict)
# if not subobj:
# return variables.ConstDictVariable(
# subobj, collections.OrderedDict, **options
# )
# if name == "__class__":
# return UserDefinedClassVariable(type(self.value), **options)
# return variables.GetAttrVariable(self, name, **options)
def call_hasattr(self, tx, name: str) -> "VariableTracker":
if tx.output.side_effects.is_attribute_mutation(self):
try: