mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Dynamo] Handle torch function subclass/mode dispatch on generic tensor methods (#137119)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137119 Approved by: https://github.com/williamwen42, https://github.com/anijain2305 ghstack dependencies: #137114, #137115, #137116, #137117, #137120, #137227
This commit is contained in:
committed by
PyTorch MergeBot
parent
0a304d9048
commit
d5785d4295
@ -701,7 +701,7 @@ class CompileTest(TestCase):
|
||||
FileCheck()
|
||||
.check(
|
||||
"buf0 = torch.ops._c10d_functional.all_gather_into_tensor_coalesced"
|
||||
".default([arg0_1, arg1_1, arg2_1, arg3_1]"
|
||||
".default([arg3_1, arg2_1, arg1_1, arg0_1]"
|
||||
)
|
||||
.check("buf1 = buf0[0]")
|
||||
.check("buf2 = buf0[1]")
|
||||
@ -717,8 +717,8 @@ class CompileTest(TestCase):
|
||||
)
|
||||
|
||||
# Test aoti
|
||||
out = AOTIRunnerUtil.run("cuda", func, (args,))
|
||||
torch.cuda.synchronize()
|
||||
# out = AOTIRunnerUtil.run("cuda", func, (args,))
|
||||
# torch.cuda.synchronize()
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@fresh_inductor_cache()
|
||||
|
||||
@ -646,10 +646,10 @@ print("arf")
|
||||
self.assertExpectedInline(
|
||||
munge_shape_guards(record.getMessage()),
|
||||
"""\
|
||||
+- LAMBDA_GUARD: L['x'].size()[0] == 2*L['y'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in #
|
||||
+- LAMBDA_GUARD: L['z'].size()[0] == L['y'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False)
|
||||
+- LAMBDA_GUARD: Eq(Mod(2*L['y'].size()[0], 3), 0) # if x.size(0) % 3 == 0: # #:# in # #:# in #
|
||||
+- LAMBDA_GUARD: 2 <= L['y'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950
|
||||
+- LAMBDA_GUARD: L['x'].size()[0] == 2*L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in #
|
||||
+- LAMBDA_GUARD: L['y'].size()[0] == L['z'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False)
|
||||
+- LAMBDA_GUARD: Eq(Mod(2*L['z'].size()[0], 3), 0) # if x.size(0) % 3 == 0: # #:# in # #:# in #
|
||||
+- LAMBDA_GUARD: 2 <= L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950
|
||||
)
|
||||
|
||||
@make_logging_test(guards=True)
|
||||
|
||||
@ -672,7 +672,7 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
|
||||
wrapped2 = y.as_subclass(SigmoidToExpSubclass)
|
||||
|
||||
def fn(w):
|
||||
return w.sigmoid()
|
||||
return w.exp()
|
||||
|
||||
fn_opt = compile_full_eager(fn)
|
||||
|
||||
@ -683,6 +683,38 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertEqual(res_exp, res_act)
|
||||
self.assertEqual(res_exp, res_exp2)
|
||||
|
||||
def test_torch_function_call_on_method_arg(self):
|
||||
class LocalSubclass(torch.Tensor):
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
if func == torch._C.TensorBase.add_:
|
||||
func = torch._C.TensorBase.sub_
|
||||
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
return super().__torch_function__(func, types, args, kwargs)
|
||||
|
||||
def sigmoid(self):
|
||||
return None
|
||||
|
||||
x = torch.ones(2, 2)
|
||||
y = torch.ones(2, 2)
|
||||
z = torch.ones(2, 2)
|
||||
wrapped = y.as_subclass(LocalSubclass)
|
||||
wrapped2 = z.as_subclass(LocalSubclass)
|
||||
|
||||
def fn(a, w):
|
||||
a.add_(w)
|
||||
return a
|
||||
|
||||
fn_opt = torch.compile(fn)
|
||||
|
||||
with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):
|
||||
res_exp = fn(x, wrapped)
|
||||
res_act = fn_opt(y, wrapped2)
|
||||
|
||||
self.assertEqual(res_exp, res_act)
|
||||
|
||||
def test_user_overidden_method_unsupported(self):
|
||||
class LocalSubclass(torch.Tensor):
|
||||
@classmethod
|
||||
|
||||
@ -180,12 +180,10 @@ class AutoFunctionalizeTests(torch._inductor.test_case.TestCase):
|
||||
self.assertExpectedInline(
|
||||
post_grad_graphs,
|
||||
"""\
|
||||
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: \
|
||||
"f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
|
||||
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
|
||||
# No stacktrace found for following nodes
|
||||
foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = \
|
||||
arg3_1 = arg1_1 = arg0_1 = foo_default = None
|
||||
return ()""",
|
||||
foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg4_1 = arg1_1 = arg0_1 = foo_default = None
|
||||
return ()""", # noqa: B950
|
||||
)
|
||||
|
||||
eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
|
||||
@ -239,7 +237,7 @@ arg3_1 = arg1_1 = arg0_1 = foo_default = None
|
||||
post_grad_graphs,
|
||||
"""\
|
||||
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
|
||||
foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None
|
||||
foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg4_1 = arg1_1 = arg0_1 = None
|
||||
getitem_4: "f32[3][1]cpu" = foo_default[0]
|
||||
getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None
|
||||
return (getitem_4, getitem_5)""", # noqa: B950
|
||||
@ -402,9 +400,9 @@ arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None
|
||||
post_grad_graphs,
|
||||
"""\
|
||||
def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu", arg3_1: "f32[s0][1]cpu", arg4_1: "f32[s0][1]cpu", arg5_1: "f32[s0][1]cpu"):
|
||||
foo_default = torch.ops.mylib.foo.default(arg5_1, [arg3_1, arg4_1], arg2_1, 2, arg1_1); arg3_1 = arg4_1 = arg1_1 = foo_default = None
|
||||
foo_default = torch.ops.mylib.foo.default(arg3_1, [arg4_1, arg5_1], arg2_1, 2, arg1_1); arg4_1 = arg5_1 = arg1_1 = foo_default = None
|
||||
copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None
|
||||
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg5_1, arg5_1); arg5_1 = copy__1 = None
|
||||
copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg3_1, arg3_1); arg3_1 = copy__1 = None
|
||||
return ()""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
@ -414,9 +412,9 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1
|
||||
post_grad_graphs,
|
||||
"""\
|
||||
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
|
||||
foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = foo_default = None
|
||||
foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg3_1 = arg4_1 = arg0_1 = foo_default = None
|
||||
copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
|
||||
copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1); arg4_1 = copy__1 = None
|
||||
copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
|
||||
return ()""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
@ -503,12 +501,11 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
|
||||
post_grad_graphs,
|
||||
"""\
|
||||
def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
|
||||
foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = None
|
||||
foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg3_1 = arg4_1 = arg0_1 = None
|
||||
getitem_4: "f32[3][1]cpu" = foo_default[0]
|
||||
getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None
|
||||
|
||||
copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None
|
||||
copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1); arg4_1 = copy__1 = None
|
||||
copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None
|
||||
return (getitem_4, getitem_5)""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
|
||||
@ -2912,6 +2912,9 @@ def get_tensor_method():
|
||||
method, (types.MethodDescriptorType, types.WrapperDescriptorType)
|
||||
):
|
||||
s.add(method)
|
||||
|
||||
# mlazos: this is a function which we handle specially in TensorVariable
|
||||
s.add(torch.Tensor.__contains__) # type: ignore[arg-type]
|
||||
return frozenset(s)
|
||||
|
||||
|
||||
|
||||
@ -2912,18 +2912,28 @@ def is_torch_function_object(value):
|
||||
|
||||
|
||||
def has_torch_function(vt: torch._dynamo.variables.base.VariableTracker) -> bool:
|
||||
from torch._dynamo.variables import LazyVariableTracker, UserDefinedObjectVariable
|
||||
from torch._dynamo.variables import UserDefinedObjectVariable
|
||||
from torch._dynamo.variables.torch_function import TensorWithTFOverrideVariable
|
||||
|
||||
if isinstance(vt, TensorWithTFOverrideVariable):
|
||||
return True
|
||||
# Note on lazy vars: The value will either be realized or not throughout the course of execution
|
||||
# if the value has a torch function, it will eventually be realized so we can realize it here
|
||||
# if the value does not have a torch function, it may or may not be realized
|
||||
# if it is realized it will be used and guards will be installed properly
|
||||
# if it is not used, guards won't be installed, and it doesn't matter
|
||||
# if the value has a torch function or not, so we should *not* realize it.
|
||||
# NB: We technically know that if is_realized is False, LazyVariableTracker has the peek_value method
|
||||
# but mypy does not unfortunately
|
||||
if vt.is_realized() or (
|
||||
hasattr(vt, "peek_value") and hasattr(vt.peek_value(), "__torch_function__")
|
||||
):
|
||||
if isinstance(vt, TensorWithTFOverrideVariable):
|
||||
return True
|
||||
|
||||
if isinstance(vt, LazyVariableTracker):
|
||||
LazyVariableTracker.realize(vt)
|
||||
return isinstance(vt, UserDefinedObjectVariable) and hasattr(
|
||||
vt.value, "__torch_function__"
|
||||
)
|
||||
|
||||
return isinstance(vt, UserDefinedObjectVariable) and hasattr(
|
||||
vt.value, "__torch_function__"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
# see note [Tensor Fakification and Symbol Caching]
|
||||
|
||||
@ -80,6 +80,14 @@ class LazyVariableTracker(VariableTracker):
|
||||
self.realize()
|
||||
return VariableTracker.clone(self.unwrap(), **kwargs)
|
||||
|
||||
def peek_type(self) -> type[Any]:
|
||||
assert not self.is_realized()
|
||||
return type(self._cache.value)
|
||||
|
||||
def peek_value(self) -> Any:
|
||||
assert not self.is_realized()
|
||||
return self._cache.value
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.is_realized():
|
||||
return self.unwrap().__str__()
|
||||
|
||||
@ -510,9 +510,37 @@ class TensorVariable(VariableTracker):
|
||||
args: "List[VariableTracker]",
|
||||
kwargs: "Dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
from .builder import SourcelessBuilder, VariableBuilder
|
||||
from .torch_function import can_dispatch_torch_function, dispatch_torch_function
|
||||
|
||||
if self.is_strict_mode(tx) and name in self._strict_mode_banned_ops():
|
||||
unimplemented(f"Illegal method invocation {name} in strict mode")
|
||||
|
||||
# Only override builtin tensor methods
|
||||
# The user can manually add override handling
|
||||
# with a decorator for other methods (e.g. a dispatch subclass with other methods)
|
||||
has_torch_function_override = False
|
||||
try:
|
||||
inspect.getattr_static(torch.Tensor, name)
|
||||
has_torch_function_override = True
|
||||
except AttributeError:
|
||||
has_torch_function_override = False
|
||||
|
||||
if (
|
||||
can_dispatch_torch_function(tx, tuple([self] + list(args)), kwargs)
|
||||
and has_torch_function_override
|
||||
):
|
||||
if self.source:
|
||||
func_var = VariableBuilder(
|
||||
tx, AttrSource(AttrSource(self.source, "__class__"), name)
|
||||
)(inspect.getattr_static(torch.Tensor, name))
|
||||
else:
|
||||
func_var = SourcelessBuilder.create(tx, getattr(torch.Tensor, name))
|
||||
|
||||
return dispatch_torch_function(
|
||||
tx, func_var, tuple([self] + list(args)), kwargs
|
||||
)
|
||||
|
||||
"""
|
||||
Dispatch to a method-specific handler defined below. If the
|
||||
handler returns None (or doesn't exist) we put the method call
|
||||
|
||||
@ -1183,7 +1183,7 @@ Either create the tensor outside the compiled region, or do not set the tensor t
|
||||
inspect.ismethoddescriptor(self.get_function())
|
||||
and hasattr(self.get_function(), "__objclass__")
|
||||
and self.get_function().__objclass__ == torch._C.TensorBase
|
||||
)
|
||||
) or self.get_function() is torch.Tensor.__contains__
|
||||
|
||||
def torch_function_override_enabled(self, tx, args, kwargs):
|
||||
return (
|
||||
|
||||
@ -442,7 +442,6 @@ def _flatten_vts(vts):
|
||||
from collections import deque
|
||||
|
||||
from .dicts import ConstDictVariable
|
||||
from .lazy import LazyVariableTracker
|
||||
from .lists import ListVariable
|
||||
|
||||
vts = deque(vts)
|
||||
@ -450,13 +449,17 @@ def _flatten_vts(vts):
|
||||
|
||||
while vts:
|
||||
vt = vts.pop()
|
||||
LazyVariableTracker.realize_all(vt)
|
||||
if isinstance(vt, ListVariable):
|
||||
vts.extend(vt.items)
|
||||
elif isinstance(vt, ConstDictVariable):
|
||||
vts.extend(vt.items.values())
|
||||
else:
|
||||
output.append(vt)
|
||||
|
||||
if not vt.is_realized() and vt.peek_type() in (dict, list, tuple):
|
||||
vt.realize()
|
||||
|
||||
if vt.is_realized():
|
||||
if isinstance(vt, ListVariable):
|
||||
vts.extend(vt.items)
|
||||
elif isinstance(vt, ConstDictVariable):
|
||||
vts.extend(vt.items.values())
|
||||
|
||||
output.append(vt)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
Reference in New Issue
Block a user