Compare commits


3 Commits

c9f7963aeb [dynamo][user-defined] Simplify and improve scope of UserDefinedObject var_getattr
ghstack-source-id: 3ae5569e3914050c7fd2d43b943622f6c5d93c5a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130169
2024-07-06 17:18:23 -07:00
b774deb215 [dynamo][user-defined] Support method descriptors
Fixes https://github.com/pytorch/pytorch/issues/120650

ghstack-source-id: a26c5df2a625804f473eaea9b8bea00526f19cc4
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130159
2024-07-05 14:45:45 -07:00
2de5e650db [dynamo] Validate check_fn
ghstack-source-id: b698395a184ed8264dcb59ae143fe588df95afef
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118448

2024-07-05 07:50:30 -07:00
14 changed files with 190 additions and 80 deletions
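For context on the "method descriptors" commit (b774deb215) above: any object that defines __get__ participates in Python's descriptor protocol, and attribute lookup on an instance is routed through that hook. A minimal standalone illustration (plain Python, not part of the diff):

class Descriptor:
    def __get__(self, instance, owner=None):
        # instance is the object being accessed (None for class access),
        # owner is the class the descriptor is defined on.
        return f"resolved via __get__ on {owner.__name__}"

class Holder:
    attr = Descriptor()

print(Holder().attr)  # resolved via __get__ on Holder
print(Holder.attr)    # class access also routes through __get__

Dynamo previously special-cased only a few such descriptors (e.g. torch.distributions.utils.lazy_property, removed further down in this diff); this stack generalizes to tracing the descriptor's __get__ directly.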

View File

@@ -3428,8 +3428,7 @@ def forward(self, x):
         example_inputs = (torch.rand(5),)
         with self.assertRaisesRegex(
-            torch._dynamo.exc.UncapturedHigherOrderOpError,
-            "Cond doesn't work unless it is captured completely with torch.compile",
+            RuntimeError, "Unmatched number of outputs from cond"
         ):
             torch._dynamo.export(
                 f_mismatch_return_length,

View File

@@ -10641,6 +10641,37 @@ fn
         res = opt_fn(x)
         self.assertEqual(ref, res)

+    def test_descriptor(self):
+        class lazy_property:
+            def __init__(self, wrapped):
+                self.wrapped = wrapped
+
+            def __get__(self, instance, obj_type=None):
+                value = self.wrapped(instance)
+                setattr(instance, self.wrapped.__name__, value)
+                return value
+
+        class UserDefined:
+            def __init__(self):
+                self.a = 3
+
+            @lazy_property
+            def length(self):
+                return 3
+
+            def run(self, x):
+                return x * self.length
+
+        obj = UserDefined()
+
+        def fn(x):
+            return obj.run(x)
+
+        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
+        x = torch.randn(4)
+        self.assertEqual(fn(x), opt_fn(x))
+        self.assertEqual(fn(x), opt_fn(x))
+
     def test_assert_size_stride(self):
         x = torch.randn(2, 3, 4)
         with self.assertRaisesRegex(
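A note on the test above: lazy_property is a non-data descriptor (it defines __get__ but not __set__), so the setattr inside __get__ installs an instance attribute that shadows the descriptor on every later lookup, and the wrapped function runs at most once per instance. That is also why opt_fn is called twice: the first call exercises the descriptor path, the second the cached-attribute path. A torch-free repro of the caching behavior:

class lazy_property:
    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __get__(self, instance, obj_type=None):
        value = self.wrapped(instance)
        setattr(instance, self.wrapped.__name__, value)
        return value

class UserDefined:
    @lazy_property
    def length(self):
        print("computing")
        return 3

obj = UserDefined()
assert "length" not in vars(obj)  # lookup would hit the descriptor
obj.length                        # prints "computing", caches the value
assert vars(obj)["length"] == 3   # instance dict now shadows the descriptor
obj.length                        # no print: plain attribute lookup wins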

View File

@@ -1363,6 +1363,9 @@ s1 > 3""",
         s = SubTensor(torch.randn(3, 10))
         f(s)

+    # Guard validation upsets the guard
+    # https://github.com/pytorch/pytorch/issues/129936
+    @unittest.expectedFailure
     def test_recompile_with_symbool_inputs(self):
         def f(pred: bool):
             if pred:

View File

@@ -3601,6 +3601,8 @@ def forward(self, x):
         ):
             torch.export.export(exported_v2.module(), (torch.randn(2, 2),))

+    # https://github.com/pytorch/pytorch/issues/129939
+    @testing.expectedFailureNonStrict
     def test_export_cond(self):
         class A(torch.nn.Module):
             def __init__(self):
@@ -4976,6 +4978,9 @@ graph():
             )
         )

+    # Guard validation upsets the guard
+    # https://github.com/pytorch/pytorch/issues/129939
+    @unittest.expectedFailure
     def test_cond_with_module_stack_export_with(self):
         class Bar(torch.nn.Module):
             def __init__(self):

View File

@@ -1411,8 +1411,8 @@ def forward(self, arg0_1):
         x = torch.randn(4)
         with self.assertRaisesRegex(
-            torch._dynamo.exc.UncapturedHigherOrderOpError,
-            "Cond doesn't work unless it is captured completely with torch.compile",
+            torch._dynamo.exc.CondOpArgsMismatchError,
+            "Expected to return same number of outputs but got:",
         ):
             make_fx(f)(x, torch.tensor(False))
@@ -1584,8 +1584,8 @@ def forward(self, arg0_1):
         x = torch.randn(4)
         with self.assertRaisesRegex(
-            torch._dynamo.exc.UncapturedHigherOrderOpError,
-            "Cond doesn't work unless it is captured completely with torch.compile",
+            torch._dynamo.exc.CondOpArgsMismatchError,
+            "Expected to return same number of outputs but got:",
         ):
             make_fx(f, tracing_mode="fake")(x, torch.tensor(False))

View File

@@ -2346,6 +2346,7 @@ known_failing_tests = {
     "test_grad_nonleaf_register_hook",  # IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors)
     "test_unpack_hooks_exec_count",  # pack/unpack saved tensor hooks firing more than once
     "test_scalar_grad_mixed_device",  # Fake Tensors aren't propagating device properly for 0-dim grads
+    "test_backward_twice_without_saved_values",  # https://github.com/pytorch/pytorch/issues/129938
 }

 if not HAS_CUDA:

View File

@@ -2103,6 +2103,7 @@ class CheckFunctionManager:
             guard.create(builder)

         self.check_fn = self.compile_check_fn(builder, guards, guard_fail_fn)
+
         # Keep track of weak references of objects with ID_MATCH guard. This
         # info is stored alongside optimized_code and check_fn and is used to
         # limit the number of cache entries with same ID_MATCH'd object.
@@ -2123,6 +2124,18 @@ class CheckFunctionManager:
             self.guard_manager.id_matched_objs = builder.id_matched_objs
             self.check_fn = self.guard_manager

+            # Check that the guard returns True. False means that we will always
+            # recompile.
+            # TODO(anijain2305, ydwu4) - Skipping export because of following test
+            # python -s test/dynamo/test_export.py -k test_export_with_symbool_inputs
+            if not output_graph.export:
+                if not self.guard_manager.check(output_graph.local_scope):
+                    reasons = get_guard_fail_reason_helper(
+                        self.guard_manager,  # type: ignore[arg-type]
+                        output_graph.local_scope,
+                    )
+                    raise AssertionError(f"Guard check failed: {reasons}")
+
         # NB - We have to very careful of cleaning up here. Because of the
         # invalidate function, we can create a weakref finalizer that keeps
         # `self` alive for very long. Sometimes by mistake, we can run
@@ -2456,9 +2469,8 @@ def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope):
     return [f"Duplicate tensors found: {reason}"]


-def get_guard_fail_reason(
+def get_guard_fail_reason_helper(
     guard_fn: GuardFn,
-    code: types.CodeType,
     f_locals: Dict[str, object],
 ) -> str:
     """
@@ -2525,6 +2537,15 @@ def get_guard_fail_reason(
                 break

     reason_str = "\n".join(reasons)
+    return reason_str
+
+
+def get_guard_fail_reason(
+    guard_fn: GuardFn,
+    code: types.CodeType,
+    f_locals: Dict[str, object],
+) -> str:
+    reason_str = get_guard_fail_reason_helper(guard_fn, f_locals)
     guard_failures[orig_code_map[code]].append(reason_str)

     try:
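In prose, this file's change does two things: get_guard_fail_reason is split so the reason computation (get_guard_fail_reason_helper) can run without a code object, and CheckFunctionManager immediately sanity-checks the freshly built guard against the very frame it was built from. A simplified sketch of that validation (the wrapper name below is hypothetical; the body mirrors the diff, with the surrounding plumbing elided):

def validate_check_fn(guard_manager, output_graph):
    # Export is skipped for now; see the TODO in the diff.
    if output_graph.export:
        return
    # A guard built from this frame must accept this frame's locals;
    # if it does not, every subsequent call would recompile.
    if not guard_manager.check(output_graph.local_scope):
        reasons = get_guard_fail_reason_helper(guard_manager, output_graph.local_scope)
        raise AssertionError(f"Guard check failed: {reasons}")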

View File

@@ -587,9 +587,16 @@ def is_wrapper_or_member_descriptor(value):
     return isinstance(
         value,
         (
-            types.MethodWrapperType,
+            # set up by PyGetSetDef
+            types.GetSetDescriptorType,
+            # set by PyMethodDef, e.g. list.append
+            types.MethodDescriptorType,
+            # slots - list.__add__
             types.WrapperDescriptorType,
+            # set up by PyMemberDef
             types.MemberDescriptorType,
+            # wrapper over C functions
             types.MethodWrapperType,
         ),
     )
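For reference, the descriptor flavors enumerated above correspond directly to how CPython's own types module derives them; the assertions below hold on modern CPython (3.7+). Standalone illustration, not part of the diff:

import types

# From CPython's Lib/types.py:
#   GetSetDescriptorType  = type(FunctionType.__code__)     # PyGetSetDef
#   MemberDescriptorType  = type(FunctionType.__globals__)  # PyMemberDef
#   MethodDescriptorType  = type(str.join)                  # PyMethodDef
#   WrapperDescriptorType = type(object.__init__)           # slot wrapper
#   MethodWrapperType     = type(object().__str__)          # bound wrapper over a C function
assert isinstance(list.append, types.MethodDescriptorType)
assert isinstance(list.__add__, types.WrapperDescriptorType)
assert isinstance([].__add__, types.MethodWrapperType)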

View File

@@ -2538,6 +2538,22 @@ class SourcelessBuilder:
         handlers[immutable_list] = handlers[list]
         handlers[types.ModuleType] = lambda tx, value: PythonModuleVariable(value)

+        handlers[
+            torch.distributions.constraints._Real
+        ] = lambda tx, value: UserDefinedObjectVariable(
+            value, mutable_local=MutableLocal()
+        )
+        handlers[
+            torch.distributions.constraints._Interval
+        ] = lambda tx, value: UserDefinedObjectVariable(
+            value, mutable_local=MutableLocal()
+        )
+        handlers[
+            torch.distributions.constraints.Constraint
+        ] = lambda tx, value: UserDefinedObjectVariable(
+            value, mutable_local=MutableLocal()
+        )
+
         def passthrough(tx, value):
             return value
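These three handlers cover torch.distributions constraint objects, which are typically reached as module-level singletons rather than through a traced attribute source, presumably why they need SourcelessBuilder handling. A quick check of what the new handlers receive (standalone, not part of the diff):

from torch.distributions import constraints

print(type(constraints.real).__name__)           # _Real
print(type(constraints.unit_interval).__name__)  # _Interval
print(isinstance(constraints.real, constraints.Constraint))  # True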

View File

@@ -4,11 +4,9 @@ import collections
 import contextlib
 import enum
 import functools
-import importlib
 import inspect
 import itertools
 import random
-import re
 import sys
 import threading
 import types
@@ -45,13 +43,13 @@ from ..source import (
     WeakRefCallSource,
 )
 from ..utils import (
-    all_hook_names,
     build_checkpoint_variable,
     check_constant_args,
     get_custom_getattr,
     has_torch_function,
     is_namedtuple_cls,
     is_utils_checkpoint,
+    is_wrapper_or_member_descriptor,
     istype,
     namedtuple_fields,
     object_has_getattribute,
@@ -829,7 +827,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
     def var_getattr(self, tx, name):
         from .. import trace_rules
         from . import ConstantVariable
-        from .builder import VariableBuilder

         value = self.value
         source = AttrSource(self.source, name) if self.source else None
@@ -842,6 +839,13 @@ class UserDefinedObjectVariable(UserDefinedVariable):
             options = {"source": source}
             return variables.GetAttrVariable(self, name, **options)

+        # TODO(anijain2305) - Investigate if we need specialization for more
+        # dunder attrs. inspect.getattr_static does not return correct value for
+        # them.
+        if name == "__class__":
+            options = {"source": source}
+            return UserDefinedClassVariable(type(self.value), **options)
+
         try:
             subobj = self._getattr_static(name)
         except AttributeError:
@@ -891,11 +895,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
             return variables.UserMethodVariable(
                 subobj.fget, self, source=source
             ).call_function(tx, [], {})
-        elif isinstance(subobj, torch.distributions.utils.lazy_property):
-            subobj_var = UserDefinedObjectVariable(subobj, source=source)
-            return variables.UserMethodVariable(
-                subobj.__get__.__func__, subobj_var, source=source
-            ).call_function(tx, [self], {})
         elif isinstance(subobj, staticmethod):
             func = subobj.__get__(self.value)
             if source is not None:
@@ -906,6 +905,25 @@ class UserDefinedObjectVariable(UserDefinedVariable):
             return variables.UserMethodVariable(
                 subobj.__func__, self.var_getattr(tx, "__class__"), source=source
             )
+        elif inspect.ismethoddescriptor(subobj) and not is_wrapper_or_member_descriptor(
+            subobj.__get__
+        ):
+            # Attribute has a __get__ method. Create a user defined object vt
+            # for the subobj, and then trace the __get__ method.
+            descriptor_var = UserDefinedObjectVariable(subobj, source=source)
+
+            get_source = self.source
+            if self.source:
+                get_source = AttrSource(self.source, "__get__")
+
+            # The arguments of the __get__ function are (self, instance, owner)
+            # self - descriptor_var
+            # instance - instance of the class, represented by self here
+            # owner - class object
+            owner_var = UserDefinedClassVariable(type(self.value))
+            return variables.UserMethodVariable(
+                subobj.__get__.__func__, descriptor_var, source=get_source
+            ).call_function(tx, [descriptor_var, self, owner_var], {})
         elif isinstance(subobj, types.FunctionType) or (
             isinstance(subobj, types.MethodType)
             and isinstance(self.value, torch.nn.Module)
@@ -943,77 +961,86 @@ class UserDefinedObjectVariable(UserDefinedVariable):
             else:
                 return trace_rules.lookup(func)(func)

-        if (
-            name in getattr(value, "__dict__", {})
-            or ConstantVariable.is_literal(subobj)
-            or isinstance(
-                subobj,
-                (
-                    torch.Tensor,
-                    torch.nn.Module,
-                    re.Pattern,
-                ),
-            )
-        ):
+        if subobj is not NO_SUCH_SUBOBJ and not is_wrapper_or_member_descriptor(subobj):
             if source:
                 return variables.LazyVariableTracker.create(subobj, source)
-            elif ConstantVariable.is_literal(subobj):
-                return ConstantVariable.create(subobj)
-            elif (
-                type(subobj) == torch.utils._pytree.TreeSpec
-                or type(subobj) == torch.utils._pytree.LeafSpec
-                or type(value) == torch.utils._pytree.TreeSpec
-            ):
+            else:
                 from .builder import SourcelessBuilder
                 return SourcelessBuilder.create(tx, subobj)
-        if (
-            subobj is not NO_SUCH_SUBOBJ
-            and name not in getattr(value, "__dict__", {})
-            and (
-                type(value).__module__.startswith("torch.")
-                or isinstance(subobj, re.Pattern)
-            )
-            and "torch.optim" not in type(value).__module__
-            and not callable(value)
-            and not isinstance(subobj, types.MethodDescriptorType)
-        ):
-            if not source:
-                assert getattr(
-                    importlib.import_module(type(value).__module__),
-                    type(value).__name__,
-                ) is type(value)
-                source = AttrSource(
-                    AttrSource(
-                        tx.import_source(type(value).__module__), type(value).__name__
-                    ),
-                    name,
-                )
-            return VariableBuilder(tx, source)(subobj)

         options = {"source": source}
-        if isinstance(
-            subobj,
-            (
-                torch.distributions.constraints._Interval,
-                torch.distributions.constraints._Real,
-                torch.distributions.constraints.Constraint,
-            ),
-        ):
-            return UserDefinedObjectVariable(subobj, **options)
-        elif isinstance(self.value, torch.nn.Module) and name in all_hook_names:
-            assert isinstance(subobj, collections.OrderedDict)
-            if not subobj:
-                return variables.ConstDictVariable(
-                    subobj, collections.OrderedDict, **options
-                )
-        if name == "__class__":
-            return UserDefinedClassVariable(type(self.value), **options)
         return variables.GetAttrVariable(self, name, **options)

+        # if (
+        #     name in getattr(value, "__dict__", {})
+        #     or ConstantVariable.is_literal(subobj)
+        #     or isinstance(
+        #         subobj,
+        #         (
+        #             torch.Tensor,
+        #             torch.nn.Module,
+        #             re.Pattern,
+        #         ),
+        #     )
+        # ):
+        #     if source:
+        #         return variables.LazyVariableTracker.create(subobj, source)
+        #     elif ConstantVariable.is_literal(subobj):
+        #         return ConstantVariable.create(subobj)
+        #     elif (
+        #         type(subobj) == torch.utils._pytree.TreeSpec
+        #         or type(subobj) == torch.utils._pytree.LeafSpec
+        #         or type(value) == torch.utils._pytree.TreeSpec
+        #     ):
+        #         from .builder import SourcelessBuilder
+        #         return SourcelessBuilder.create(tx, subobj)
+        # if (
+        #     subobj is not NO_SUCH_SUBOBJ
+        #     and name not in getattr(value, "__dict__", {})
+        #     and (
+        #         type(value).__module__.startswith("torch.")
+        #         or isinstance(subobj, re.Pattern)
+        #     )
+        #     and "torch.optim" not in type(value).__module__
+        #     and not callable(value)
+        #     and not isinstance(subobj, types.MethodDescriptorType)
+        # ):
+        #     if not source:
+        #         assert getattr(
+        #             importlib.import_module(type(value).__module__),
+        #             type(value).__name__,
+        #         ) is type(value)
+        #         source = AttrSource(
+        #             AttrSource(
+        #                 tx.import_source(type(value).__module__), type(value).__name__
+        #             ),
+        #             name,
+        #         )
+        #     return VariableBuilder(tx, source)(subobj)
+        # options = {"source": source}
+        # if isinstance(
+        #     subobj,
+        #     (
+        #         torch.distributions.constraints._Interval,
+        #         torch.distributions.constraints._Real,
+        #         torch.distributions.constraints.Constraint,
+        #     ),
+        # ):
+        #     return UserDefinedObjectVariable(subobj, **options)
+        # elif isinstance(self.value, torch.nn.Module) and name in all_hook_names:
+        #     assert isinstance(subobj, collections.OrderedDict)
+        #     if not subobj:
+        #         return variables.ConstDictVariable(
+        #             subobj, collections.OrderedDict, **options
+        #         )
+        # if name == "__class__":
+        #     return UserDefinedClassVariable(type(self.value), **options)
+        # return variables.GetAttrVariable(self, name, **options)

     def call_hasattr(self, tx, name: str) -> "VariableTracker":
         if tx.output.side_effects.is_attribute_mutation(self):
             try:
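Closing the loop on the ismethoddescriptor branch added earlier in this file: it mirrors the manual form of the descriptor protocol, fetching the raw descriptor from the class and invoking __get__ with (instance, owner). A plain-Python analogue (standalone, not part of the diff):

class Descr:
    def __get__(self, instance, owner=None):
        return 42

class C:
    attr = Descr()

obj = C()
subobj = type(obj).__dict__["attr"]          # the raw descriptor, as _getattr_static sees it
assert subobj.__get__(obj, type(obj)) == 42  # the call Dynamo now traces
assert obj.attr == 42                        # normal attribute lookup, same result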