[Dynamo] Handle torch function subclass/mode dispatch on generic tensor methods (#137119)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137119
Approved by: https://github.com/williamwen42, https://github.com/anijain2305
ghstack dependencies: #137114, #137115, #137116, #137117, #137120, #137227
Michael Lazos
2024-10-08 14:11:04 -07:00
committed by PyTorch MergeBot
parent 0a304d9048
commit d5785d4295
10 changed files with 119 additions and 38 deletions
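For context, the behavior this change enables looks roughly like the sketch below, modeled on the test_torch_function_call_on_method_arg test added in this commit: a tensor subclass's __torch_function__ now intercepts a builtin tensor method (here add_) inside a compiled region, and the compiled result matches eager. This is a minimal sketch, not the exact test; the traceable_tensor_subclasses config patch mirrors the test and may not be required on every version.

import torch
import torch._dynamo

class LocalSubclass(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Redirect in-place add to in-place sub so the override is observable.
        if func == torch._C.TensorBase.add_:
            func = torch._C.TensorBase.sub_
        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs)

def fn(a, w):
    a.add_(w)  # w is a LocalSubclass, so this dispatches through __torch_function__
    return a

x = torch.ones(2, 2)
y = torch.ones(2, 2)
w = torch.ones(2, 2).as_subclass(LocalSubclass)
w2 = torch.ones(2, 2).as_subclass(LocalSubclass)

with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):
    res_eager = fn(x, w)                     # eager: add_ is rewritten to sub_
    res_compiled = torch.compile(fn)(y, w2)  # compiled: should trace the same dispatch
    assert torch.equal(res_eager, res_compiled)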


@@ -701,7 +701,7 @@ class CompileTest(TestCase):
FileCheck()
.check(
"buf0 = torch.ops._c10d_functional.all_gather_into_tensor_coalesced"
".default([arg0_1, arg1_1, arg2_1, arg3_1]"
".default([arg3_1, arg2_1, arg1_1, arg0_1]"
)
.check("buf1 = buf0[0]")
.check("buf2 = buf0[1]")
@@ -717,8 +717,8 @@
)
# Test aoti
out = AOTIRunnerUtil.run("cuda", func, (args,))
torch.cuda.synchronize()
# out = AOTIRunnerUtil.run("cuda", func, (args,))
# torch.cuda.synchronize()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache()


@@ -646,10 +646,10 @@ print("arf")
self.assertExpectedInline(
munge_shape_guards(record.getMessage()),
"""\
+- LAMBDA_GUARD: L['x'].size()[0] == 2*L['y'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in #
+- LAMBDA_GUARD: L['z'].size()[0] == L['y'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False)
+- LAMBDA_GUARD: Eq(Mod(2*L['y'].size()[0], 3), 0) # if x.size(0) % 3 == 0: # #:# in # #:# in #
+- LAMBDA_GUARD: 2 <= L['y'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950
+- LAMBDA_GUARD: L['x'].size()[0] == 2*L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in #
+- LAMBDA_GUARD: L['y'].size()[0] == L['z'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False)
+- LAMBDA_GUARD: Eq(Mod(2*L['z'].size()[0], 3), 0) # if x.size(0) % 3 == 0: # #:# in # #:# in #
+- LAMBDA_GUARD: 2 <= L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950
)
@make_logging_test(guards=True)


@@ -672,7 +672,7 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
wrapped2 = y.as_subclass(SigmoidToExpSubclass)
def fn(w):
return w.sigmoid()
return w.exp()
fn_opt = compile_full_eager(fn)
@@ -683,6 +683,38 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
self.assertEqual(res_exp, res_act)
self.assertEqual(res_exp, res_exp2)
def test_torch_function_call_on_method_arg(self):
class LocalSubclass(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if func == torch._C.TensorBase.add_:
func = torch._C.TensorBase.sub_
if kwargs is None:
kwargs = {}
return super().__torch_function__(func, types, args, kwargs)
def sigmoid(self):
return None
x = torch.ones(2, 2)
y = torch.ones(2, 2)
z = torch.ones(2, 2)
wrapped = y.as_subclass(LocalSubclass)
wrapped2 = z.as_subclass(LocalSubclass)
def fn(a, w):
a.add_(w)
return a
fn_opt = torch.compile(fn)
with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):
res_exp = fn(x, wrapped)
res_act = fn_opt(y, wrapped2)
self.assertEqual(res_exp, res_act)
def test_user_overidden_method_unsupported(self):
class LocalSubclass(torch.Tensor):
@classmethod


@@ -180,12 +180,10 @@ class AutoFunctionalizeTests(torch._inductor.test_case.TestCase):
self.assertExpectedInline(
post_grad_graphs,
"""\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: \
"f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
# No stacktrace found for following nodes
foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = \
arg3_1 = arg1_1 = arg0_1 = foo_default = None
return ()""",
foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg4_1 = arg1_1 = arg0_1 = foo_default = None
return ()""", # noqa: B950
)
eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
@@ -239,7 +237,7 @@ arg3_1 = arg1_1 = arg0_1 = foo_default = None
post_grad_graphs,
"""\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None
foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg4_1 = arg1_1 = arg0_1 = None
getitem_4: "f32[3][1]cpu" = foo_default[0]
getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None
return (getitem_4, getitem_5)""", # noqa: B950
@@ -402,9 +400,9 @@ arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None
post_grad_graphs,
"""\
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu", arg3_1: "f32[s0][1]cpu", arg4_1: "f32[s0][1]cpu", arg5_1: "f32[s0][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg5_1, [arg3_1, arg4_1], arg2_1, 2, arg1_1); arg3_1 = arg4_1 = arg1_1 = foo_default = None
foo_default = torch.ops.mylib.foo.default(arg3_1, [arg4_1, arg5_1], arg2_1, 2, arg1_1); arg4_1 = arg5_1 = arg1_1 = foo_default = None
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg5_1, arg5_1); arg5_1 = copy__1 = None
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg3_1, arg3_1); arg3_1 = copy__1 = None
return ()""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@@ -414,9 +412,9 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1
post_grad_graphs,
"""\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = foo_default = None
foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg3_1 = arg4_1 = arg0_1 = foo_default = None
copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1); arg4_1 = copy__1 = None
copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
return ()""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
@@ -503,12 +501,11 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
post_grad_graphs,
"""\
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = None
foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg3_1 = arg4_1 = arg0_1 = None
getitem_4: "f32[3][1]cpu" = foo_default[0]
getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None
copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1); arg4_1 = copy__1 = None
copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
return (getitem_4, getitem_5)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,


@@ -2912,6 +2912,9 @@ def get_tensor_method():
method, (types.MethodDescriptorType, types.WrapperDescriptorType)
):
s.add(method)
# mlazos: this is a function which we handle specially in TensorVariable
s.add(torch.Tensor.__contains__) # type: ignore[arg-type]
return frozenset(s)
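torch.Tensor.__contains__ is defined in Python (torch/_tensor.py) rather than as a method descriptor on torch._C.TensorBase, so the isinstance check above never collects it and it has to be added explicitly. Below is a minimal eager-mode sketch of what a subclass's __torch_function__ observes for the `in` operator; LoggingSubclass is an illustrative helper, not part of this commit.

import torch

class LoggingSubclass(torch.Tensor):
    seen = []  # record every func that reaches __torch_function__

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        cls.seen.append(func)
        return super().__torch_function__(func, types, args, kwargs or {})

t = torch.arange(4).as_subclass(LoggingSubclass)
_ = 2 in t  # the `in` operator funnels through torch.Tensor.__contains__
assert torch.Tensor.__contains__ in LoggingSubclass.seen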


@@ -2912,18 +2912,28 @@ def is_torch_function_object(value):
def has_torch_function(vt: torch._dynamo.variables.base.VariableTracker) -> bool:
from torch._dynamo.variables import LazyVariableTracker, UserDefinedObjectVariable
from torch._dynamo.variables import UserDefinedObjectVariable
from torch._dynamo.variables.torch_function import TensorWithTFOverrideVariable
if isinstance(vt, TensorWithTFOverrideVariable):
return True
# Note on lazy vars: The value will either be realized or not throughout the course of execution
# if the value has a torch function, it will eventually be realized so we can realize it here
# if the value does not have a torch function, it may or may not be realized
# if it is realized it will be used and guards will be installed properly
# if it is not used, guards won't be installed, and it doesn't matter
# if the value has a torch function or not, so we should *not* realize it.
# NB: We technically know that if is_realized is False, LazyVariableTracker has the peek_value method
# but mypy does not unfortunately
if vt.is_realized() or (
hasattr(vt, "peek_value") and hasattr(vt.peek_value(), "__torch_function__")
):
if isinstance(vt, TensorWithTFOverrideVariable):
return True
if isinstance(vt, LazyVariableTracker):
LazyVariableTracker.realize(vt)
return isinstance(vt, UserDefinedObjectVariable) and hasattr(
vt.value, "__torch_function__"
)
return isinstance(vt, UserDefinedObjectVariable) and hasattr(
vt.value, "__torch_function__"
)
return False
# see note [Tensor Fakification and Symbol Caching]


@@ -80,6 +80,14 @@ class LazyVariableTracker(VariableTracker):
self.realize()
return VariableTracker.clone(self.unwrap(), **kwargs)
def peek_type(self) -> type[Any]:
assert not self.is_realized()
return type(self._cache.value)
def peek_value(self) -> Any:
assert not self.is_realized()
return self._cache.value
def __str__(self) -> str:
if self.is_realized():
return self.unwrap().__str__()


@@ -510,9 +510,37 @@ class TensorVariable(VariableTracker):
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
from .builder import SourcelessBuilder, VariableBuilder
from .torch_function import can_dispatch_torch_function, dispatch_torch_function
if self.is_strict_mode(tx) and name in self._strict_mode_banned_ops():
unimplemented(f"Illegal method invocation {name} in strict mode")
# Only override builtin tensor methods
# The user can manually add override handling
# with a decorator for other methods (e.g. a dispatch subclass with other methods)
has_torch_function_override = False
try:
inspect.getattr_static(torch.Tensor, name)
has_torch_function_override = True
except AttributeError:
has_torch_function_override = False
if (
can_dispatch_torch_function(tx, tuple([self] + list(args)), kwargs)
and has_torch_function_override
):
if self.source:
func_var = VariableBuilder(
tx, AttrSource(AttrSource(self.source, "__class__"), name)
)(inspect.getattr_static(torch.Tensor, name))
else:
func_var = SourcelessBuilder.create(tx, getattr(torch.Tensor, name))
return dispatch_torch_function(
tx, func_var, tuple([self] + list(args)), kwargs
)
"""
Dispatch to a method-specific handler defined below. If the
handler returns None (or doesn't exist) we put the method call


@@ -1183,7 +1183,7 @@ Either create the tensor outside the compiled region, or do not set the tensor t
inspect.ismethoddescriptor(self.get_function())
and hasattr(self.get_function(), "__objclass__")
and self.get_function().__objclass__ == torch._C.TensorBase
)
) or self.get_function() is torch.Tensor.__contains__
def torch_function_override_enabled(self, tx, args, kwargs):
return (


@@ -442,7 +442,6 @@ def _flatten_vts(vts):
from collections import deque
from .dicts import ConstDictVariable
from .lazy import LazyVariableTracker
from .lists import ListVariable
vts = deque(vts)
@@ -450,13 +449,17 @@
while vts:
vt = vts.pop()
LazyVariableTracker.realize_all(vt)
if isinstance(vt, ListVariable):
vts.extend(vt.items)
elif isinstance(vt, ConstDictVariable):
vts.extend(vt.items.values())
else:
output.append(vt)
if not vt.is_realized() and vt.peek_type() in (dict, list, tuple):
vt.realize()
if vt.is_realized():
if isinstance(vt, ListVariable):
vts.extend(vt.items)
elif isinstance(vt, ConstDictVariable):
vts.extend(vt.items.values())
output.append(vt)
return output